# Copyright 2025 The TensorStore Authors
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generates the .pyi type stubs for TensorStore."""

import argparse
import ast
import difflib
import importlib
import importlib.resources
import os
import pathlib
import sys
import tempfile

import black
import pybind11_stubgen
import pybind11_stubgen.parser.mixins.fix


# On Windows, required DLLs may live in directories listed in PATH rather than
# in the default DLL search locations, so register every existing PATH entry
# as a DLL directory.  (`os.add_dll_directory` only exists on Windows.)
if hasattr(os, "add_dll_directory"):
  search_path = os.environ.get("PATH")
  entries = [] if search_path is None else search_path.split(os.pathsep)
  for entry in entries:
    # Non-existent entries are skipped; add_dll_directory would reject them.
    if os.path.isdir(entry):
      os.add_dll_directory(os.path.abspath(entry))


def _monkey_patch_fix_numpy_dtype() -> None:
  """Works around a pybind11_stubgen bug: every type named `dtype` is treated as `numpy.dtype`."""
  mixin_cls = pybind11_stubgen.parser.mixins.fix.FixNumpyDtype  # pylint: disable=invalid-name
  original_parse = mixin_cls.parse_annotation_str

  def _patched_parse(self, annotation_str: str):
    # Annotations mentioning tensorstore (e.g. `tensorstore.dtype`) must skip
    # the numpy-specific fixup and go straight to the next mixin in the MRO.
    if "tensorstore" in annotation_str:
      return super(mixin_cls, self).parse_annotation_str(annotation_str)
    return original_parse(self, annotation_str)

  setattr(mixin_cls, "parse_annotation_str", _patched_parse)


# Install the patch at import time, before `pybind11_stubgen.main` is invoked
# in `main()` below.
_monkey_patch_fix_numpy_dtype()

# Extra statements prepended to the munged `__init__.pyi` by
# `transform_init_ast`.  `T_contra` is referenced by the `Promise` class fixup
# there; `asyncio` is presumably needed by other generated annotations — TODO
# confirm against the generated stubs.
_STUB_EXTRA_CONTENT = """
import asyncio
T_contra = typing.TypeVar("T_contra", contravariant=True)
"""


def _is_symbol_def(node: ast.AST) -> str | None:
  """Determines if the given AST node defines a symbol."""
  match node:
    case ast.FunctionDef(name):
      return name
    case ast.ClassDef(name):
      return name
    case ast.Assign(targets=[ast.Name(id=_ as name)]):
      return name
    case ast.AnnAssign(target=ast.Name(id=_ as name)):
      return name
    case _:
      return None


def _get_defined_symbols(tree: ast.Module) -> set[str]:
  """Collects the names of all top-level symbols defined in `tree`."""
  # `_is_symbol_def` returns None for statements that define nothing.
  return {
      name
      for stmt in tree.body
      if (name := _is_symbol_def(stmt)) is not None
  }


def _strip_module_prefix(tree: ast.AST, module_prefix: str) -> ast.AST:
  """Strips the module prefix from Attribute nodes in the given AST."""

  class _Transformer(ast.NodeTransformer):

    def visit_Attribute(self, node: ast.Attribute):  # pylint: disable=invalid-name
      match node.value:
        case ast.Name(id=_ as name) if name == module_prefix:
          return ast.Name(id=node.attr)
      return self.generic_visit(node)

  return _Transformer().visit(tree)


def _remove_duplicate_imports(body: list[ast.AST]) -> list[ast.AST]:
  """Removes duplicate imports from the given AST."""
  new_body: list[ast.AST] = []
  seen_imports: set[str] = set()
  for node in body:
    match node:
      case ast.Import(names=[name]) | ast.ImportFrom(names=[name]):
        asname = name.asname or name.name
        if asname not in seen_imports:
          new_body.append(node)
          seen_imports.add(asname)
      case _:
        new_body.append(node)
  return new_body


def transform_init_ast(
    init_py_content: str,
    init_pyi_tree: ast.Module,
    submodule_names: set[str],
) -> ast.AST:
  """Munges the `__init__.pyi` AST generated by pybind11_stubgen.

  The way that tensorstore renames all symbols that are actually defined in the
  `_tensorstore` module to appear to be defined in `tensorstore` confuses
  pybind11_stubgen.

  Args:
    init_py_content: The content of the `__init__.py` file.
    init_pyi_tree: The AST of the `__init__.pyi` file generated by
      pybind11_stubgen.
    submodule_names: The set of submodule names.

  Returns:
    The modified `__init__.pyi` AST.
  """
  init_py_tree = ast.parse(init_py_content)

  # Remove initial doc comment from init_pyi_tree.  (The doc comment from
  # `__init__.py`, if present, is re-inserted at the top near the end of this
  # function instead.)
  match init_pyi_tree.body:
    case [ast.Expr(value=ast.Constant(value=str())), *_]:
      del init_pyi_tree.body[0]

  # Exclude any definitions of symbols (including class members) by these names.
  #
  # These are internal implementation details that aren't useful in the type
  # stubs.
  excluded_symbols = {
      "_unpickle",
      "__reduce__",
      "__getstate__",
      "__setstate__",
      "__conditional_annotations__",
      "__class_getitem__",
      "__del__",
      "__type_params__",
  }

  # Removes `excluded_symbols` definitions (annotated assignments and function
  # defs) from the pyi tree.  Class bodies are traversed by the default
  # `generic_visit`, so class members are removed as well.
  class _InitPyVisitor(ast.NodeTransformer):

    def visit_AnnAssign(self, node: ast.AnnAssign):  # pylint: disable=invalid-name
      match node.target:
        case ast.Name(id=name) if name in excluded_symbols:
          return None
        case _:
          return node

    def visit_FunctionDef(self, node: ast.FunctionDef):  # pylint: disable=invalid-name
      # NOTE(review): `ast.AsyncFunctionDef` is not handled here; presumably
      # the stubgen output contains no async defs — confirm.
      return None if node.name in excluded_symbols else node

  init_pyi_tree = _InitPyVisitor().visit(init_pyi_tree)

  # Filter out statements from `init_py` that shouldn't be included in
  # type stubs.
  def _include_py_node(node: ast.AST) -> bool:
    match node:
      case ast.Delete():
        # Excludes the `del ...` statements in `__init__.py`.
        return False
      case ast.Expr(value):
        match value:
          case ast.Constant():
            # Include doc comments.
            return True
          case _:
            # Other bare expressions (e.g. calls) have no place in a stub.
            return False
      case ast.ImportFrom(module="_tensorstore"):
        # Excludes the `from _tensorstore import ...` statements.  The
        # `_tensorstore` module isn't intended to be exposed in the type stubs.
        return False
      case _:
        return True

  # Symbols defined in `__init__.py` take precedence over the generated stubs.
  excluded_top_level_symbols = _get_defined_symbols(init_py_tree)
  excluded_top_level_symbols.add("_Decodable")

  # Filter out statements from `init_pyi` that shouldn't be included in the type
  # stubs.
  def _include_pyi_node(node: ast.AST) -> bool:
    match node:
      case ast.ImportFrom(names=[ast.alias(name="_tensorstore")]):
        # Excludes the `from _tensorstore import ...` statements.  The
        # `_tensorstore` module isn't intended to be exposed in the type stubs.
        return False
      case _:
        symbol = _is_symbol_def(node)
        if symbol is None:
          return True
        # Exclude symbols that are already defined in `__init__.py`.  The
        # definition from `__init__.py` is used instead.
        return symbol not in excluded_top_level_symbols

  new_body: list[ast.AST] = []

  # Saved module docstring, re-inserted at index 0 after sorting (see below).
  module_doc_comment_node: ast.AST | None = None

  match init_py_tree.body:
    case [ast.Expr(value=ast.Constant(value=str())) as node, *_]:
      module_doc_comment_node = node
      del init_py_tree.body[0]

  # Merge order: extra helper statements, then `__init__.py` statements, then
  # the generated pyi statements.  Imports are hoisted by the sort below.
  new_body.extend(ast.parse(_STUB_EXTRA_CONTENT).body)
  new_body.extend(node for node in init_py_tree.body if _include_py_node(node))
  new_body.extend(
      node for node in init_pyi_tree.body if _include_pyi_node(node)
  )

  def _rename_identifier(node: ast.AST, old_id: str, new_id: str) -> None:
    """Renames an identifier in the given AST node."""

    class Visitor(ast.NodeVisitor):

      def visit_Name(self, n: ast.Name):  # pylint: disable=invalid-name
        if n.id == old_id:
          n.id = new_id

    Visitor().visit(node)

  # Apply per-class fixups to the merged body.
  for node in new_body:
    match node:
      # Modify class definitions.
      case ast.ClassDef(name=name) as class_node:
        match name:
          case "Future":
            # Add Generic base class.
            class_node.bases.append(
                ast.parse("typing.Generic[T]", mode="eval").body
            )
          case "Promise":
            # Add Generic base class and rename `T` to `T_contra` to have
            # correct contravariance.  (`T_contra` is declared in
            # `_STUB_EXTRA_CONTENT`.)
            class_node.bases.append(
                ast.parse("typing.Generic[T_contra]", mode="eval").body
            )
            _rename_identifier(class_node, "T", "T_contra")
          case "IndexTransform" | "Spec" | "TensorStore":
            # Add Indexable base class.
            class_node.bases.append(ast.parse("Indexable", mode="eval").body)
          case "IndexDomain":
            # Add an explicit `__iter__` method returning `Iterator[Dim]`.
            # NOTE(review): original comment read "to avoid mypy sees it as
            # iterable"; presumably this pins mypy's element type — confirm.
            class_node.body.extend(
                ast.parse(
                    "def __iter__(self) -> collections.abc.Iterator[Dim]: ..."
                ).body
            )
          case "OutputIndexMaps":
            # Add explicit `__iter__` method to ensure mypy sees it as iterable.
            class_node.body.extend(
                ast.parse(
                    "def __iter__(self) ->"
                    " collections.abc.Iterator[OutputIndexMap]: ..."
                ).body
            )
      # Remove submodules from `__all__`.
      case ast.Assign(targets=[ast.Name(id="__all__")], value=value):
        match value:
          case ast.List(elts=elts):
            value.elts = [
                elt
                for elt in elts
                if isinstance(elt, ast.Constant)
                and elt.value not in submodule_names
            ]

  def _sort_order(node: ast.AST) -> int:
    match node:
      case ast.ImportFrom(module="__future__"):
        # __future__ imports should be first.
        return -1
      case ast.ImportFrom() | ast.Import():
        # Other imports are next.
        return 0
      case _:
        return 1

  # `list.sort` is stable, so relative order within each category is kept.
  new_body.sort(key=_sort_order)
  new_body = _remove_duplicate_imports(new_body)

  if module_doc_comment_node is not None:
    new_body.insert(0, module_doc_comment_node)

  init_pyi_tree.body = new_body

  # Some annotations include `tensorstore.` prefix, which is incorrect.
  init_pyi_tree = _strip_module_prefix(init_pyi_tree, "tensorstore")

  return init_pyi_tree


def _fix_all(tree: ast.Module) -> ast.Module:
  """Excludes the Literal type annotation on `__all__` added by pybind11_stubgen.

  The type annotation breaks type checkers.

  Args:
    tree: The AST to modify.

  Returns:
    The modified AST.
  """
  for i, node in enumerate(tree.body):
    match node:
      case ast.AnnAssign(target=ast.Name(id="__all__")):
        tree.body[i] = ast.Assign(
            targets=[node.target], value=node.value, type_comment=None
        )
        return tree
  return tree


def _strip_generic_slice(tree: ast.Module) -> ast.Module:
  """Strips type parameters from `slice`.

  These are supported by mypy but not by pytype.

  Args:
    tree: The AST to modify.

  Returns:
    The modified AST.
  """

  class _Transformer(ast.NodeTransformer):

    def visit_Subscript(self, node: ast.Subscript):  # pylint: disable=invalid-name
      match node.value:
        case ast.Name(id="slice") as name:
          return name
      return self.generic_visit(node)

    def visit_Constant(self, node: ast.Constant):  # pylint: disable=invalid-name
      if isinstance(node.value, str) and node.value.startswith("slice["):
        node.value = "slice"
      return node

  return _Transformer().visit(tree)


def _munge_type_stubs_file(
    input_path: pathlib.Path,
    strip_generic_slice: bool,
    submodule_names: set[str],
) -> str:
  """Reads a generated .pyi file, rewrites its AST, and reformats it.

  Every stub gets `_fix_all`; `__init__.pyi` additionally goes through the
  full `transform_init_ast` munging (merged with the installed package's
  `__init__.py`); `_strip_generic_slice` is applied on request.  The result
  is unparsed and formatted with black.
  """
  raw = input_path.read_text(encoding="utf-8")
  tree = ast.parse(raw)
  _fix_all(tree)
  if input_path.name == "__init__.pyi":
    ts_init = pathlib.Path(importlib.import_module("tensorstore").__file__)
    tree = transform_init_ast(
        init_py_content=ts_init.read_text(encoding="utf-8"),
        init_pyi_tree=tree,
        submodule_names=submodule_names,
    )
  if strip_generic_slice:
    tree = _strip_generic_slice(tree)
  # Synthesized nodes lack positions; fill them in before unparsing.
  ast.fix_missing_locations(tree)
  text = ast.unparse(tree).replace("typing.Tuple[", "tuple[")
  return black.format_str(text, mode=black.Mode())


def main() -> None:
  """Generates the stubs and either writes them out or diffs them against disk.

  In validate mode (default), exits with status 1 if any packaged stub differs
  from the freshly generated one.  With --generate, writes the stubs to
  --output-dir (defaulting to this script's source directory).
  """
  ap = argparse.ArgumentParser()
  ap.add_argument(
      "--output-dir",
      type=pathlib.Path,
      # Bug fix: the default used to be the source dir itself, which made the
      # fallback-and-validate logic below unreachable dead code.  Default to
      # None so that logic actually runs.
      default=None,
      help=(
          "Output paths for the generated type stubs.  Defaults to source dir."
      ),
  )
  ap.add_argument(
      "--generate",
      action="store_true",
      help=(
          "Instead of validating existing stubs, write the generated stubs to"
          " disk."
      ),
  )
  ap.add_argument(
      "--strip-generic-slice",
      action="store_true",
      help=(
          "Strip type parameters from `slice`.  Recent versions of mypy support"
          " generic slice but pytype does not currently."
      ),
  )
  args = ap.parse_args()
  output_dir = args.output_dir
  if args.generate and output_dir is None:
    # Fall back to the directory containing this script, and sanity-check that
    # it really is the source directory before writing stubs into it.
    output_dir = pathlib.Path(os.path.dirname(os.path.realpath(__file__)))
    if not os.path.exists(os.path.join(output_dir, os.path.basename(__file__))):
      raise ValueError(f"Invalid source directory {output_dir}")
  mismatch: list[str] = []
  with tempfile.TemporaryDirectory() as tempdir:
    pybind11_stubgen.main([
        "tensorstore",
        "-o",
        tempdir,
        "--ignore-unresolved-names",
        r"^_GlobalPicklableFunction$|^_abc\.",
    ])
    stub_dir = pathlib.Path(tempdir)

    # All generated stubs except the internal `_tensorstore` module, which is
    # not exposed in the published type stubs.
    tmp_outputs = [
        path
        for path in stub_dir.rglob("*.pyi")
        if path.name != "_tensorstore.pyi"
    ]
    submodule_names = {path.stem for path in tmp_outputs}
    for input_path in tmp_outputs:
      name = input_path.name
      content = _munge_type_stubs_file(
          input_path,
          strip_generic_slice=args.strip_generic_slice,
          submodule_names=submodule_names,
      )
      if args.generate:
        # Bug fix: write via the resolved `output_dir` (which may have been
        # defaulted above), not `args.output_dir`, which can now be None.
        (output_dir / name).write_text(content, encoding="utf-8")
      else:
        # Validate mode: diff the freshly generated stub against the copy
        # packaged with the installed `tensorstore` module.
        existing_content = (
            importlib.resources.files("tensorstore") / name
        ).read_text("utf-8")
        if existing_content != content:
          mismatch.append(name)
          print(f"Stubs for {name} out of date.", file=sys.stderr)
          print(
              "\n".join(
                  difflib.unified_diff(
                      existing_content.splitlines(keepends=False),
                      content.splitlines(keepends=False),
                  )
              ),
              file=sys.stderr,
          )
  if mismatch:
    print(
        f"Mismatch in stub files: {mismatch}\n"
        "To regenerate: "
        "bazel run //python/tensorstore:generate_type_stubs -- "
        "--generate",
        file=sys.stderr,
    )
    sys.exit(1)


if __name__ == "__main__":
  main()
